import normflows as nf
import torch

def create_invertible_glow_basic(layers=1):
    """Stack ``layers`` Glow blocks on a (4, 64, 64) diagonal Gaussian base.

    Every block is ``nf.flows.GlowBlock(4, 256, split_mode='channel',
    scale=True)``; no multi-scale splitting or merging is performed.

    Args:
        layers: Number of ``GlowBlock`` layers to stack.

    Returns:
        An ``nf.NormalizingFlow`` built from the base distribution and the
        list of Glow blocks.
    """
    num_channels = 4
    hidden_width = 256  # hidden channel count handed to each GlowBlock
    block_options = dict(split_mode='channel', scale=True)

    base_dist = nf.distributions.base.DiagGaussian((num_channels, 64, 64))
    glow_blocks = [
        nf.flows.GlowBlock(num_channels, hidden_width, **block_options)
        for _ in range(layers)
    ]
    return nf.NormalizingFlow(base_dist, glow_blocks)

def create_invertible_residual_basic(layers=1):
    """Build an invertible residual flow over (4, 64, 64) tensors.

    Each of the ``layers`` flows is an ``nf.flows.Residual`` wrapping an
    ``nf.nets.LipschitzCNN`` with channel widths 4 -> 64 -> 64 -> 4, kernel
    sizes 3 / 1 / 3, ``init_zeros=True`` and ``lipschitz_const=0.9``.

    Args:
        layers: Number of residual flow layers.

    Returns:
        An ``nf.NormalizingFlow`` whose base (q0) is a non-trainable
        diagonal Gaussian of shape (4, 64, 64).
    """
    latent_size = (4, 64, 64)
    n_channels = latent_size[0]

    residual_flows = []
    for _ in range(layers):
        # A fresh channel/kernel list per network, as in the original code.
        lipschitz_net = nf.nets.LipschitzCNN(
            [n_channels, 64, 64, n_channels],
            [3, 1, 3],
            init_zeros=True,
            lipschitz_const=0.9,
        )
        residual_flows.append(nf.flows.Residual(lipschitz_net))

    # The prior is built after the flows, matching the original construction order.
    prior = nf.distributions.DiagGaussian(latent_size, trainable=False)
    return nf.NormalizingFlow(q0=prior, flows=residual_flows)

def create_invertible_residual_complex(layers=2):
    """Build an invertible residual flow over (8, 64, 64) tensors.

    Each of the ``layers`` flows is an ``nf.flows.Residual`` wrapping an
    ``nf.nets.LipschitzCNN`` with channel widths C -> 64 -> 64 -> C (C taken
    from ``latent_size``), kernel sizes 3 / 1 / 3, ``init_zeros=True`` and
    ``lipschitz_const=0.9``.

    Args:
        layers: Number of residual flow layers (default 2).

    Returns:
        An ``nf.NormalizingFlow`` whose base (q0) is a non-trainable
        diagonal Gaussian of shape ``latent_size``.
    """
    latent_size = (8, 64, 64)
    # Derive the channel count from latent_size so the networks and the base
    # distribution cannot drift apart if the shape is ever changed.
    channels = latent_size[0]
    hidden_units = 64
    hidden_layers = 3

    flows = []
    for _ in range(layers):
        net = nf.nets.LipschitzCNN(
            [channels] + [hidden_units] * (hidden_layers - 1) + [channels],
            [3, 1, 3],
            init_zeros=True,
            lipschitz_const=0.9,
        )
        flows.append(nf.flows.Residual(net))
        # NOTE: an nf.flows.ActNorm(latent_size) layer after each residual
        # block was present here but disabled; re-add deliberately if needed.

    # Set prior (q0); reuse latent_size rather than repeating the literal shape.
    q0 = nf.distributions.DiagGaussian(latent_size, trainable=False)

    # Construct flow model
    return nf.NormalizingFlow(q0=q0, flows=flows)

def create_invertible_real_nvp_basic(layers=1):
    """Build a fully-connected Real NVP flow over flat vectors of length 4*64*64.

    The base is a 16384-dimensional diagonal Gaussian (flat, not image-shaped).
    Each layer is an ``AffineCouplingBlock`` followed by a ``Permute`` in
    ``'swap'`` mode.

    Args:
        layers: Number of (coupling, permute) pairs.

    Returns:
        An ``nf.NormalizingFlow`` over 16384-dimensional vectors.
    """
    dim = 4 * 64 * 64
    half = dim // 2  # == 2*64*64, the input width of each coupling network

    base = nf.distributions.base.DiagGaussian(dim)

    flows = []
    for _ in range(layers):
        # MLP([half, dim]) lists no hidden layers: it maps `half` inputs
        # straight to `dim` outputs. init_zeros=True is passed so its last
        # layer starts at zero.
        coupling = nf.flows.AffineCouplingBlock(nf.nets.MLP([half, dim], init_zeros=True))
        flows.extend([coupling, nf.flows.Permute(dim, mode='swap')])

    return nf.NormalizingFlow(base, flows)

def create_invertible_real_nvp_conv(layers=1):
    """Build a convolutional Real NVP flow over (4, 64, 64) tensors.

    Each layer is an ``AffineCouplingBlock`` whose parameter map is
    ``nf.nets.ConvResidualNet(2, 4, 64, num_blocks=2)``, followed by a
    ``Permute(4, mode='swap')``.

    Args:
        layers: Number of (coupling, permute) pairs.

    Returns:
        An ``nf.NormalizingFlow`` with a diagonal Gaussian base of shape
        (4, 64, 64).
    """
    channels = 4
    base = nf.distributions.base.DiagGaussian((channels, 64, 64))

    def coupling_then_swap():
        # ConvResidualNet(2 in, 4 out, 64 hidden channels, 2 residual blocks).
        # NOTE(review): unlike the MLP in the flat variant, no zero-init flag
        # is passed here; confirm whether the final layer should start at zero.
        net = nf.nets.ConvResidualNet(channels // 2, channels, 64, num_blocks=2)
        return [nf.flows.AffineCouplingBlock(net), nf.flows.Permute(channels, mode='swap')]

    flows = [flow for _ in range(layers) for flow in coupling_then_swap()]
    return nf.NormalizingFlow(base, flows)

def create_mlp():
    """Return a bottleneck MLP mapping (N, 4, 64, 64) tensors to the same shape.

    Layout: Flatten -> Linear(16384, 64) -> ReLU -> Linear(64, 16384)
    -> Unflatten back to (4, 64, 64).
    """
    image_shape = (4, 64, 64)
    flat_size = 4 * 64 * 64
    bottleneck = 64

    stages = [
        torch.nn.Flatten(),
        torch.nn.Linear(flat_size, bottleneck),
        torch.nn.ReLU(),
        torch.nn.Linear(bottleneck, flat_size),
        torch.nn.Unflatten(1, image_shape),
    ]
    return torch.nn.Sequential(*stages)
